"""
Phase 3: Mechanisms & Robustness — Sectoral Savings Decomposition
=================================================================
Deeper analysis beyond phase 2 baselines:
  (a) Cointegration tests (Kao test on savings relationships)
  (b) Bootstrap standard errors for headline results
  (c) Placebo / permutation test
  (d) Leave-one-out country analysis
  (e) Regional jackknife
  (f) Consumption decomposition: Z -> private/govt consumption
  (g) Z x kaopen interactions
  (h) Eurozone subsample
  (i) Structural break: pre/post GFC
  (j) Aging speed: dZ -> savings
  (k) Investment decomposition (gross vs fixed)
  (l) Z x fiscal balance interactions
  (m) Z x trade openness interactions

Output: output/tables/ (markdown files)
"""

import sys
from pathlib import Path
import numpy as np
import pandas as pd
from scipy import stats

# Project layout. NOTE(review): PROJECT_DIR is a hard-coded absolute path, so
# the script only runs where that directory exists — confirm this is intended.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/sectoral_savings")
ROOT_DIR = PROJECT_DIR.parent
# The sibling project's src directory must be on sys.path before `model` is
# imported on the next line, which is why this import sits below the path setup.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel location and output directory for the markdown tables.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"
# Side effect at import time: the output directory is created if absent.
TABLES_DIR.mkdir(parents=True, exist_ok=True)

# ISO3 codes of 38 OECD members. Not referenced by the code visible in this file.
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]

# Eurozone 19 with staggered join dates.
# Maps ISO3 -> first year counted as a euro-area member; main() flags an
# observation as a member when year >= this value.
# NOTE(review): Croatia (HRV, euro adoption 2023) is not listed — confirm
# whether the panel extends past 2022 and whether it should be added.
EZ_JOIN = {
    'AUT': 1999, 'BEL': 1999, 'FIN': 1999, 'FRA': 1999, 'DEU': 1999,
    'IRL': 1999, 'ITA': 1999, 'LUX': 1999, 'NLD': 1999, 'PRT': 1999,
    'ESP': 1999, 'GRC': 2001, 'SVN': 2007, 'CYP': 2008, 'MLT': 2008,
    'SVK': 2009, 'EST': 2011, 'LVA': 2014, 'LTU': 2015,
}

# Region name -> ISO3 codes, used by the regional jackknife in main(): each
# iteration drops every country listed under one region.
# Codes that appear in no list (e.g. USA, CAN, MLT) are never dropped by any
# jackknife iteration.
REGIONS = {
    'East Asia & Pacific': ['AUS', 'BRN', 'CHN', 'FJI', 'HKG', 'IDN', 'JPN', 'KHM',
                            'KOR', 'LAO', 'MAC', 'MMR', 'MNG', 'MYS', 'NZL', 'PHL',
                            'PNG', 'SGP', 'THA', 'TLS', 'TON', 'VNM', 'VUT', 'WSM'],
    'Europe & Central Asia': ['ALB', 'ARM', 'AUT', 'AZE', 'BEL', 'BGR', 'BIH', 'BLR',
                              'CHE', 'CYP', 'CZE', 'DEU', 'DNK', 'ESP', 'EST', 'FIN',
                              'FRA', 'GBR', 'GEO', 'GRC', 'HRV', 'HUN', 'IRL', 'ISL',
                              'ITA', 'KAZ', 'KGZ', 'LTU', 'LUX', 'LVA', 'MDA', 'MKD',
                              'MNE', 'NLD', 'NOR', 'POL', 'PRT', 'ROU', 'RUS', 'SRB',
                              'SVK', 'SVN', 'SWE', 'TJK', 'TKM', 'TUR', 'UKR', 'UZB'],
    'Latin America': ['ARG', 'BLZ', 'BOL', 'BRA', 'CHL', 'COL', 'CRI', 'CUB', 'DOM',
                      'ECU', 'GTM', 'GUY', 'HND', 'HTI', 'JAM', 'MEX', 'NIC', 'PAN',
                      'PER', 'PRY', 'SLV', 'SUR', 'TTO', 'URY', 'VEN'],
    'Middle East & N. Africa': ['ARE', 'BHR', 'DZA', 'EGY', 'IRN', 'IRQ', 'ISR', 'JOR',
                                 'KWT', 'LBN', 'LBY', 'MAR', 'OMN', 'QAT', 'SAU', 'TUN',
                                 'YEM'],
    'South Asia': ['AFG', 'BGD', 'BTN', 'IND', 'LKA', 'MDV', 'NPL', 'PAK'],
    'Sub-Saharan Africa': ['AGO', 'BDI', 'BEN', 'BFA', 'BWA', 'CAF', 'CIV', 'CMR',
                           'COD', 'COG', 'COM', 'CPV', 'DJI', 'ERI', 'ETH', 'GAB',
                           'GHA', 'GIN', 'GMB', 'GNB', 'GNQ', 'KEN', 'LBR', 'LSO',
                           'MDG', 'MLI', 'MOZ', 'MRT', 'MUS', 'MWI', 'NAM', 'NER',
                           'NGA', 'RWA', 'SDN', 'SEN', 'SLE', 'SOM', 'SSD', 'STP',
                           'SWZ', 'SYC', 'TCD', 'TGO', 'TZA', 'UGA', 'ZAF', 'ZMB',
                           'ZWE'],
}


def stars(p):
    """Return the significance marker for p-value ``p``.

    '***' below 0.01, '**' below 0.05, '*' below 0.10, otherwise ''.
    Thresholds are strict, so a value exactly on a cutoff gets the weaker
    marker.
    """
    thresholds = ((0.01, '***'), (0.05, '**'), (0.10, '*'))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ''


def run_model(df, dep_var, regressors, label):
    """Fit PanelGLS of ``dep_var`` on ``regressors`` and summarise the fit.

    Args:
        df: Panel containing ``iso3`` and ``year`` columns plus the model
            variables.
        dep_var: Name of the dependent-variable column.
        regressors: Requested regressor column names. Names absent from
            ``df`` are dropped from the model and reported on stdout, so a
            reduced specification is never fitted silently.
        label: Model label used in log lines and in the returned dict.

    Returns:
        A dict with ``label``, ``dep_var``, ``n_obs``, ``n_countries``,
        ``r_squared`` and, per regressor actually used, ``coef_<name>``,
        ``se_<name>`` and ``p_<name>``. Returns ``None`` when ``dep_var`` is
        missing, when none of the requested regressors exist, or when fewer
        than 50 complete-case rows remain.
    """
    if dep_var not in df.columns:
        print(f"  [{label}] {dep_var} missing")
        return None

    dropped = [r for r in regressors if r not in df.columns]
    regressors = [r for r in regressors if r in df.columns]
    if dropped:
        # Surface the reduced specification instead of fitting it silently.
        print(f"  [{label}] regressors not in data, dropped: {', '.join(dropped)}")
    if not regressors:
        # A zero-column design matrix cannot be estimated.
        print(f"  [{label}] no regressors available")
        return None

    # Complete cases only: every model variable must be observed.
    sub = df.dropna(subset=[dep_var] + regressors).copy()
    if len(sub) < 50:
        print(f"  [{label}] Insufficient obs ({len(sub)})")
        return None

    gls = PanelGLS()
    gls.fit(sub[dep_var].values, sub[regressors].values,
            sub['iso3'].values, sub['year'].values)
    print(f"  [{label}] N={gls.n_obs}, R2={gls.r_squared:.4f}")

    results = {
        'label': label, 'dep_var': dep_var,
        'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
    }
    # beta / se / pvalues are indexed in the order the regressors were passed.
    for i, name in enumerate(regressors):
        results[f'coef_{name}'] = gls.beta[i]
        results[f'se_{name}'] = gls.se[i]
        results[f'p_{name}'] = gls.pvalues[i]
    return results


def build_table(results, key_vars, notes, filename, title):
    """Render model results as a markdown report under TABLES_DIR.

    The report has a model-summary table (label, dependent variable, N,
    countries, R2) followed by a coefficient table restricted to
    ``key_vars``; a variable is listed for a model only if that model's
    result dict carries a ``coef_<var>`` entry. ``notes`` is appended in
    italics. Nothing is written when ``results`` is empty.
    """
    if not results:
        return

    summary_rows = [
        f"| {res['label']} | {res['dep_var']} | {res['n_obs']:,} "
        f"| {res['n_countries']} | {res['r_squared']:.3f} |"
        for res in results
    ]

    coef_rows = []
    for res in results:
        for var in key_vars:
            coef_key = f'coef_{var}'
            if coef_key not in res:
                continue
            pval = res[f'p_{var}']
            coef_rows.append(
                f"| {res['label']} | {var} | {res[coef_key]:.4f} "
                f"| {res[f'se_{var}']:.4f} | {pval:.4f} | {stars(pval)} |"
            )

    lines = [
        f"# {title}\n",
        "| Model | Dep Var | N | Countries | R2 |",
        "|---|---|---|---|---|",
        *summary_rows,
        "\n## Key Coefficients\n",
        "| Model | Variable | Coef | SE | p-value | Sig |",
        "|---|---|---|---|---|---|",
        *coef_rows,
        f"\n*{notes}*",
    ]

    target = TABLES_DIR / filename
    target.write_text('\n'.join(lines))
    print(f"  Saved: {target}")


# Fixed seed for the bootstrap (Part B) and permutation (Part C) draws so that
# repeated runs write identical tables.
_RANDOM_SEED = 20240101

# (column, short label) for the three savings aggregates used in Parts A-C.
_SAVINGS_SPECS = [
    ('gross_national_savings_gdp', 'national savings'),
    ('private_saving_gdp', 'private saving'),
    ('govt_saving_gdp', 'govt saving'),
]


def _banner(title):
    """Print a section banner for one part of the analysis."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)


def _write_md(filename, lines):
    """Write markdown ``lines`` to TABLES_DIR / filename and report the path."""
    out = TABLES_DIR / filename
    out.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {out}")


def _fit(sample, dep, regs):
    """Fit PanelGLS of ``dep`` on ``regs`` over ``sample`` and return the model."""
    gls = PanelGLS()
    gls.fit(sample[dep].values, sample[regs].values,
            sample['iso3'].values, sample['year'].values)
    return gls


def _complete_cases(df, dep, regs, min_obs=100):
    """Return rows where ``dep`` and all ``regs`` are observed, else None.

    None is returned (with a log line) when ``dep`` is not a column or when
    fewer than ``min_obs`` complete rows remain.
    """
    if dep not in df.columns:
        print(f"  {dep}: column missing, skipped")
        return None
    sample = df.dropna(subset=[dep] + regs).copy()
    if len(sample) < min_obs:
        print(f"  {dep}: insufficient obs ({len(sample)}), skipped")
        return None
    return sample


def _run_specs(df, specs, regs):
    """Run run_model for each (dep_var, label) in ``specs``; keep the successes."""
    fitted = []
    for dep, label in specs:
        res = run_model(df, dep, regs, label)
        if res:
            fitted.append(res)
    return fitted


def _part_a_cointegration(df, regs):
    """Part A: per-country ADF-type regressions on the savings residuals."""
    _banner("PART A: Cointegration Tests (Kao-style ADF on residuals)")

    kao_results = []
    for dep, _ in _SAVINGS_SPECS:
        sub = _complete_cases(df, dep, regs)
        if sub is None:
            continue
        # NOTE(review): assumes PanelGLS.resid is in the same row order as the
        # arrays passed to fit() — confirm against the model implementation.
        sub['resid'] = _fit(sub, dep, regs).resid

        # Regress the change in the residual on its lagged level, country by
        # country, and collect the t-statistics of the slope.
        adf_tstats = []
        for _, csub in sub.groupby('iso3', sort=False):
            r = csub.sort_values('year')['resid'].values
            if len(r) < 8:
                continue
            dr = np.diff(r)
            r_lag = r[:-1]
            if np.std(r_lag) < 1e-10:
                continue
            reg = stats.linregress(r_lag, dr)
            if reg.stderr > 0:
                adf_tstats.append(reg.slope / reg.stderr)

        if not adf_tstats:
            continue
        mean_t = np.mean(adf_tstats)
        n_panels = len(adf_tstats)
        # Kao test statistic: mean_t * sqrt(N) (approximately)
        kao_stat = mean_t * np.sqrt(n_panels)
        # One-sided, left-tail p-value: H0 is rejected for large negative
        # statistics, so the normal CDF is used without doubling.
        kao_p = stats.norm.cdf(kao_stat)
        kao_results.append({
            'dep_var': dep, 'n_panels': n_panels,
            'mean_adf_t': mean_t, 'kao_stat': kao_stat, 'kao_p': kao_p,
        })
        reject = "REJECT (cointegrated)" if kao_p < 0.05 else "fail to reject"
        print(f"  {dep}: mean ADF t = {mean_t:.3f}, Kao stat = {kao_stat:.3f}, "
              f"p = {kao_p:.4f} -> {reject}")

    if not kao_results:
        return
    md = ["# Cointegration Tests (Kao-style)\n",
          "| Dep Var | N Panels | Mean ADF t | Kao Stat | p-value | Result |",
          "|---|---|---|---|---|---|"]
    for r in kao_results:
        result = "Cointegrated" if r['kao_p'] < 0.05 else "Not cointegrated"
        md.append(f"| {r['dep_var']} | {r['n_panels']} | {r['mean_adf_t']:.3f} "
                  f"| {r['kao_stat']:.3f} | {r['kao_p']:.4f} | {result} |")
    md.append("\n*Kao (1999) panel cointegration test. H0: no cointegration. "
              "ADF regressions on residuals from Z -> savings.*")
    _write_md("cointegration.md", md)


def _part_b_bootstrap(df, regs, rng, n_boot=500):
    """Part B: country-cluster bootstrap of the Z_1 coefficient."""
    _banner("PART B: Bootstrap Standard Errors (headline results)")

    boot_results = []
    for dep, label in _SAVINGS_SPECS:
        sub = _complete_cases(df, dep, regs)
        if sub is None:
            continue

        # Point estimate (Z_1 is the first regressor).
        gls = _fit(sub, dep, regs)
        z1_point = gls.beta[0]
        z1_se_ols = gls.se[0]

        # Split once by country so each replication only concatenates.
        by_country = dict(list(sub.groupby('iso3', sort=False)))
        countries = list(by_country)
        boot_z1 = []
        n_failed = 0
        last_err = None
        for _ in range(n_boot):
            picks = rng.integers(0, len(countries), size=len(countries))
            # A country drawn twice gets distinct ids so it enters as two panels.
            bdf = pd.concat(
                [by_country[countries[i]].assign(iso3=f"{countries[i]}_{j}")
                 for j, i in enumerate(picks)],
                ignore_index=True,
            )
            try:
                boot_z1.append(_fit(bdf, dep, regs).beta[0])
            except Exception as exc:
                # Best effort: PanelGLS failure modes are not known here, so a
                # failed replication is counted and reported, not fatal.
                n_failed += 1
                last_err = exc
        if n_failed:
            print(f"  {label}: {n_failed}/{n_boot} bootstrap fits failed "
                  f"(last error: {last_err})")

        if boot_z1:
            draws = np.asarray(boot_z1)
            boot_se = np.std(draws)
            boot_ci_lo, boot_ci_hi = np.percentile(draws, [2.5, 97.5])
            boot_p = 2 * min(np.mean(draws > 0), np.mean(draws < 0))
        else:
            boot_se = boot_ci_lo = boot_ci_hi = boot_p = np.nan

        boot_results.append({
            'dep_var': label, 'z1_point': z1_point,
            'z1_se_ols': z1_se_ols, 'z1_se_boot': boot_se,
            'boot_ci_lo': boot_ci_lo, 'boot_ci_hi': boot_ci_hi,
            'boot_p': boot_p, 'n_boot': len(boot_z1),
        })
        print(f"  {label}: Z1 = {z1_point:.1f}, OLS SE = {z1_se_ols:.1f}, "
              f"Boot SE = {boot_se:.1f}, 95% CI = [{boot_ci_lo:.1f}, {boot_ci_hi:.1f}]")

    if not boot_results:
        return
    md = ["# Bootstrap Standard Errors\n",
          "| Dep Var | Z1 Point | OLS SE | Boot SE | 95% CI | Boot p |",
          "|---|---|---|---|---|---|"]
    for r in boot_results:
        md.append(f"| {r['dep_var']} | {r['z1_point']:.1f} | {r['z1_se_ols']:.1f} "
                  f"| {r['z1_se_boot']:.1f} | [{r['boot_ci_lo']:.1f}, {r['boot_ci_hi']:.1f}] "
                  f"| {r['boot_p']:.4f} |")
    md.append(f"\n*{n_boot} country-cluster bootstrap replications.*")
    _write_md("bootstrap.md", md)


def _part_c_permutation(df, demo_vars, regs, rng, n_perm=500):
    """Part C: placebo test permuting demographics across countries within year."""
    _banner("PART C: Placebo / Permutation Test")

    perm_results = []
    for dep in ['gross_national_savings_gdp', 'private_saving_gdp']:
        sub = _complete_cases(df, dep, regs)
        if sub is None:
            continue

        true_z1 = _fit(sub, dep, regs).beta[0]

        # Row positions per year, computed once and reused by every draw.
        z_values = sub[demo_vars].to_numpy()
        years = sub['year'].to_numpy()
        year_rows = [np.flatnonzero(years == yr) for yr in pd.unique(years)]

        perm_z1 = []
        n_failed = 0
        last_err = None
        for _ in range(n_perm):
            shuffled = z_values.copy()
            for rows in year_rows:
                # One permutation per year moves all Z columns together, so
                # each country's Z vector stays intact.
                shuffled[rows] = z_values[rows[rng.permutation(len(rows))]]
            psub = sub.copy()
            psub[demo_vars] = shuffled
            try:
                perm_z1.append(_fit(psub, dep, regs).beta[0])
            except Exception as exc:
                # Best effort: count and report failed draws instead of aborting.
                n_failed += 1
                last_err = exc
        if n_failed:
            print(f"  {dep}: {n_failed}/{n_perm} permutation fits failed "
                  f"(last error: {last_err})")

        if perm_z1:
            draws = np.asarray(perm_z1)
            perm_p = np.mean(np.abs(draws) >= np.abs(true_z1))
            perm_mean = np.mean(draws)
            perm_sd = np.std(draws)
        else:
            perm_p = perm_mean = perm_sd = np.nan

        perm_results.append({
            'dep_var': dep, 'true_z1': true_z1,
            'perm_mean': perm_mean, 'perm_sd': perm_sd,
            'perm_p': perm_p, 'n_perm': len(perm_z1),
        })
        print(f"  {dep}: true Z1 = {true_z1:.1f}, perm mean = {perm_mean:.1f}, "
              f"perm p = {perm_p:.4f}")

    if not perm_results:
        return
    md = ["# Placebo / Permutation Test\n",
          "| Dep Var | True Z1 | Perm Mean | Perm SD | Perm p-value |",
          "|---|---|---|---|---|"]
    for r in perm_results:
        md.append(f"| {r['dep_var']} | {r['true_z1']:.1f} | {r['perm_mean']:.1f} "
                  f"| {r['perm_sd']:.1f} | {r['perm_p']:.4f} |")
    md.append(f"\n*{n_perm} permutations. Demographics permuted across countries within year.*")
    _write_md("placebo.md", md)


def _part_d_leave_one_out(df, regs):
    """Part D: refit dropping one country at a time.

    Returns the full-sample Z_1 coefficient (NaN if the sample is unusable)
    so the regional jackknife can report it.
    """
    _banner("PART D: Leave-One-Out Country Analysis")

    dep = 'gross_national_savings_gdp'
    sub = _complete_cases(df, dep, regs)
    if sub is None:
        return np.nan

    gls = _fit(sub, dep, regs)
    full_z1 = gls.beta[0]
    print(f"  Full sample Z1 = {full_z1:.2f}")

    loo_results = []
    failed = []
    for iso3 in sub['iso3'].unique():
        loo = sub[sub['iso3'] != iso3]
        try:
            z1 = _fit(loo, dep, regs).beta[0]
        except Exception as exc:
            # Best effort: report the country whose refit failed and move on.
            failed.append(f"{iso3} ({exc})")
            continue
        loo_results.append({
            'dropped': iso3, 'z1': z1,
            'change_pct': (z1 / full_z1 - 1) * 100,
        })
    if failed:
        print(f"  LOO fits failed for: {'; '.join(failed)}")

    if not loo_results:
        return full_z1

    loo_df = pd.DataFrame(loo_results).sort_values('change_pct')
    min_z1 = loo_df.iloc[0]
    max_z1 = loo_df.iloc[-1]
    range_pct = max_z1['change_pct'] - min_z1['change_pct']

    print(f"  Z1 range: [{loo_df['z1'].min():.1f}, {loo_df['z1'].max():.1f}]")
    print(f"  Most influential (reduces): {min_z1['dropped']} -> Z1 = {min_z1['z1']:.1f} ({min_z1['change_pct']:+.1f}%)")
    print(f"  Most influential (increases): {max_z1['dropped']} -> Z1 = {max_z1['z1']:.1f} ({max_z1['change_pct']:+.1f}%)")

    # Five most influential countries in each direction.
    top_reduce = loo_df.head(5)
    top_increase = loo_df.tail(5)

    md = ["# Leave-One-Out Country Analysis\n"]
    md.append(f"Full sample: Z1 = {full_z1:.2f} (N = {gls.n_obs:,})\n")
    md.append(f"LOO range: [{loo_df['z1'].min():.1f}, {loo_df['z1'].max():.1f}] "
              f"(spread = {range_pct:.1f}% of full estimate)\n")
    md.append("## Most Influential Countries (Dropping Reduces Z1)\n")
    md.append("| Country | Z1 | Change (%) |")
    md.append("|---|---|---|")
    for _, r in top_reduce.iterrows():
        md.append(f"| {r['dropped']} | {r['z1']:.2f} | {r['change_pct']:+.1f}% |")
    md.append("\n## Most Influential Countries (Dropping Increases Z1)\n")
    md.append("| Country | Z1 | Change (%) |")
    md.append("|---|---|---|")
    for _, r in top_increase.iterrows():
        md.append(f"| {r['dropped']} | {r['z1']:.2f} | {r['change_pct']:+.1f}% |")
    md.append(f"\n*LOO on {dep}. {len(loo_results)} countries.*")
    _write_md("leave_one_out.md", md)
    return full_z1


def _part_e_regional_jackknife(df, regs, full_z1):
    """Part E: refit dropping every country of one REGIONS entry at a time."""
    _banner("PART E: Regional Jackknife")

    dep = 'gross_national_savings_gdp'
    sub = _complete_cases(df, dep, regs)
    if sub is None:
        return

    rj_results = []
    for region, iso3s in REGIONS.items():
        rj = sub[~sub['iso3'].isin(iso3s)]
        if len(rj) < 100:
            continue
        try:
            rgls = _fit(rj, dep, regs)
        except Exception as exc:
            # Best effort: report the failed region and continue with the rest.
            print(f"  Drop {region}: fit failed ({exc})")
            continue
        rj_results.append({
            'region_dropped': region,
            'z1': rgls.beta[0], 'se': rgls.se[0], 'p': rgls.pvalues[0],
            'n_obs': rgls.n_obs, 'n_countries': rgls.n_countries,
        })
        print(f"  Drop {region}: Z1 = {rgls.beta[0]:.1f} ({stars(rgls.pvalues[0])}), "
              f"N = {rgls.n_obs}")

    if not rj_results:
        return
    md = ["# Regional Jackknife\n"]
    md.append(f"Full sample: Z1 = {full_z1:.2f}\n")
    md.append("| Region Dropped | Z1 | SE | p-value | Sig | N | Countries |")
    md.append("|---|---|---|---|---|---|---|")
    for r in rj_results:
        md.append(f"| {r['region_dropped']} | {r['z1']:.2f} | {r['se']:.2f} "
                  f"| {r['p']:.4f} | {stars(r['p'])} | {r['n_obs']:,} | {r['n_countries']} |")
    md.append(f"\n*Regional jackknife on {dep}.*")
    _write_md("regional_jackknife.md", md)


def _part_f_consumption(df, demo_vars, regs):
    """Part F: demographics -> consumption components."""
    _banner("PART F: Consumption Decomposition")
    f_results = _run_specs(df, [
        ('private_consumption_gdp', 'Z -> private consumption'),
        ('govt_consumption_gdp', 'Z -> govt consumption'),
        ('adj_net_national_savings', 'Z -> adj net savings'),
    ], regs)
    build_table(f_results, demo_vars,
                "Consumption decomposition: demographics -> consumption components",
                "consumption_decomposition.md",
                "Consumption Decomposition: Z -> Consumption Components")


def _part_g_kaopen_interactions(df, demo_vars, regs):
    """Part G: savings models with Z x kaopen interaction terms."""
    _banner("PART G: Z x KAOPEN Interactions")
    interaction_vars = ['Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen']
    absent = [v for v in interaction_vars if v not in df.columns]
    if absent:
        # Without this guard run_model would drop the absent columns and fit
        # a model with no interactions under an "x kaopen" label.
        print(f"  Skipped: interaction columns missing from panel: {', '.join(absent)}")
        return
    g_results = _run_specs(df, [
        ('gross_national_savings_gdp', 'savings x kaopen'),
        ('private_saving_gdp', 'private saving x kaopen'),
        ('govt_saving_gdp', 'govt saving x kaopen'),
    ], regs + interaction_vars)
    build_table(g_results, demo_vars + interaction_vars,
                "Z x KAOPEN interactions: does capital openness moderate savings response?",
                "kaopen_interactions.md",
                "Capital Openness Interactions: Z x KAOPEN -> Savings")


def _part_h_eurozone(df, demo_vars, regs):
    """Part H: split by time-varying euro-area membership from EZ_JOIN."""
    _banner("PART H: Eurozone Subsample")

    # Member from the join year onward; countries not in EZ_JOIN map to NaN,
    # and a comparison with NaN is False, so they are never members.
    is_member = df['year'] >= df['iso3'].map(EZ_JOIN)
    ez = df[is_member].copy()
    non_ez = df[~is_member].copy()

    h_results = []
    for dep, label_prefix in [
        ('gross_national_savings_gdp', 'savings'),
        ('govt_saving_gdp', 'govt_saving'),
        ('private_saving_gdp', 'private_saving'),
    ]:
        for frame, tag in [(ez, 'EZ'), (non_ez, 'non-EZ')]:
            r = run_model(frame, dep, regs, f"{tag}: Z -> {label_prefix}")
            if r:
                h_results.append(r)

    build_table(h_results, demo_vars,
                "Eurozone subsample (time-varying membership). Fixed exchange rate -> different savings adjustment?",
                "eurozone.md",
                "Eurozone vs Non-Eurozone Savings Decomposition")


def _part_i_structural_break(df, demo_vars, regs):
    """Part I: separate fits before and after the 2007/2008 break."""
    _banner("PART I: Structural Break — Pre/Post GFC")

    pre_gfc = df[df['year'] <= 2007].copy()
    post_gfc = df[df['year'] >= 2008].copy()

    i_results = []
    for dep, label_prefix in [
        ('gross_national_savings_gdp', 'savings'),
        ('private_saving_gdp', 'private_saving'),
        ('govt_saving_gdp', 'govt_saving'),
    ]:
        for frame, tag in [(pre_gfc, 'Pre-GFC'), (post_gfc, 'Post-GFC')]:
            r = run_model(frame, dep, regs, f"{tag} Z -> {label_prefix}")
            if r:
                i_results.append(r)

    build_table(i_results, demo_vars,
                "Structural break at 2007/2008. Pre-GFC: 1990-2007. Post-GFC: 2008-2024.",
                "structural_break.md",
                "Structural Break: Pre/Post GFC Savings Decomposition")

    # NOTE(review): only the residual sums of squares are reported; no Chow F
    # statistic is formed here (its degrees of freedom depend on PanelGLS).
    print("\n  Chow test for structural break (savings):")
    dep = 'gross_national_savings_gdp'
    for period_label, period_df in [('Full', df), ('Pre-GFC', pre_gfc), ('Post-GFC', post_gfc)]:
        sub = _complete_cases(period_df, dep, regs, min_obs=50)
        if sub is None:
            continue
        gls = _fit(sub, dep, regs)
        rss = np.sum(gls.resid ** 2)
        print(f"  {period_label}: RSS = {rss:.2f}, N = {gls.n_obs}")


def _part_j_aging_speed(df, demo_vars, controls):
    """Part J: changes in Z (and levels plus changes) -> savings."""
    _banner("PART J: Aging Speed — dZ -> Savings")

    d_demo_vars = ['d_Z_1', 'd_Z_2', 'd_Z_3']
    j_results = _run_specs(df, [
        ('gross_national_savings_gdp', 'dZ -> national savings'),
        ('private_saving_gdp', 'dZ -> private saving'),
        ('govt_saving_gdp', 'dZ -> govt saving'),
        ('ca_gdp', 'dZ -> CA/GDP'),
    ], d_demo_vars + controls)

    # Also: level + change together
    j_results += _run_specs(df, [
        ('gross_national_savings_gdp', 'Z + dZ -> national savings'),
        ('private_saving_gdp', 'Z + dZ -> private saving'),
    ], demo_vars + d_demo_vars + controls)

    build_table(j_results, demo_vars + d_demo_vars,
                "Aging speed: does the rate of demographic change matter? "
                "dZ = year-on-year change in Z.",
                "aging_speed.md",
                "Aging Speed: Rate of Demographic Change -> Savings")


def _part_k_investment(df, demo_vars, regs):
    """Part K: gross, fixed and inventory (gross minus fixed) investment."""
    _banner("PART K: Investment Decomposition")

    k_results = _run_specs(df, [
        ('gross_investment_gdp', 'Z -> gross investment'),
        ('gross_fixed_investment_gdp', 'Z -> fixed investment'),
    ], regs)

    # Inventory investment residual
    if {'gross_investment_gdp', 'gross_fixed_investment_gdp'} <= set(df.columns):
        work = df.assign(
            inventory_investment_gdp=(df['gross_investment_gdp'] -
                                      df['gross_fixed_investment_gdp']))
        r = run_model(work, 'inventory_investment_gdp', regs,
                      'Z -> inventory investment')
        if r:
            k_results.append(r)

    build_table(k_results, demo_vars,
                "Investment decomposition: gross, fixed, and inventory (residual).",
                "investment_decomposition.md",
                "Investment Decomposition: Z -> Investment Components")


def _part_l_fiscal_interactions(df, demo_vars, regs):
    """Part L: savings models with Z x fiscal balance interaction terms."""
    _banner("PART L: Fiscal Balance Interaction (Z x fiscal_bal)")

    if 'fiscal_bal_gdp' not in df.columns:
        print("  Skipped: fiscal_bal_gdp missing from panel")
        return

    fiscal_int_vars = [f'{z}_x_fiscal' for z in demo_vars]
    work = df.assign(**{
        name: df[z] * df['fiscal_bal_gdp']
        for z, name in zip(demo_vars, fiscal_int_vars)
    })

    l_results = _run_specs(work, [
        ('gross_national_savings_gdp', 'savings x fiscal'),
        ('private_saving_gdp', 'private saving x fiscal'),
    ], regs + ['fiscal_bal_gdp'] + fiscal_int_vars)

    build_table(l_results, demo_vars + fiscal_int_vars + ['fiscal_bal_gdp'],
                "Z x fiscal balance interactions: does fiscal stance moderate savings response?",
                "fiscal_interactions.md",
                "Fiscal Interactions: Z x Fiscal Balance -> Savings")


def _part_m_trade_interactions(df, demo_vars, regs):
    """Part M: savings models with Z x trade openness interaction terms."""
    _banner("PART M: Trade Openness Interaction")

    trade_int_vars = ['Z_1_x_trade', 'Z_2_x_trade', 'Z_3_x_trade']
    absent = [v for v in trade_int_vars if v not in df.columns]
    if absent:
        print(f"  Skipped: interaction columns missing from panel: {', '.join(absent)}")
        return

    m_results = _run_specs(df, [
        ('gross_national_savings_gdp', 'savings x trade'),
        ('private_saving_gdp', 'private saving x trade'),
    ], regs + trade_int_vars)

    build_table(m_results, demo_vars + trade_int_vars,
                "Z x trade openness interactions.",
                "trade_interactions.md",
                "Trade Openness Interactions: Z x Trade -> Savings")


def main():
    """Run every Phase 3 section (Parts A-M) and write its markdown table.

    Reads ``sectoral_panel.csv`` from PROCESSED_DIR. The regressors are the
    three Z demographic variables plus whichever of ``rgdp_growth``,
    ``kaopen`` and ``nfa_gdp_lag`` exist in the panel. Bootstrap and
    permutation draws share one generator seeded with ``_RANDOM_SEED``.
    """
    print("=" * 70)
    print("PHASE 3: Mechanisms & Robustness — Sectoral Savings")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "sectoral_panel.csv")
    print(f"Panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = [c for c in ['rgdp_growth', 'kaopen', 'nfa_gdp_lag'] if c in df.columns]
    # Z_1 is first, so beta[0] is the headline coefficient in Parts B-E.
    regs = demo_vars + controls
    rng = np.random.default_rng(_RANDOM_SEED)

    _part_a_cointegration(df, regs)
    _part_b_bootstrap(df, regs, rng)
    _part_c_permutation(df, demo_vars, regs, rng)
    full_z1 = _part_d_leave_one_out(df, regs)
    _part_e_regional_jackknife(df, regs, full_z1)
    _part_f_consumption(df, demo_vars, regs)
    _part_g_kaopen_interactions(df, demo_vars, regs)
    _part_h_eurozone(df, demo_vars, regs)
    _part_i_structural_break(df, demo_vars, regs)
    _part_j_aging_speed(df, demo_vars, controls)
    _part_k_investment(df, demo_vars, regs)
    _part_l_fiscal_interactions(df, demo_vars, regs)
    _part_m_trade_interactions(df, demo_vars, regs)

    print("\n" + "=" * 70)
    print("Phase 3 complete. Output tables saved to:")
    print(f"  {TABLES_DIR}")
    print("=" * 70)


# Run the full Phase 3 pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
